{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "DCGAN相比于GAN做出的改变是：\n",
    "\n",
    "1、G网络中使用转置卷积（transposed convolutional layer）进行上采样，D网络中用加入stride的卷积代替pooling。\n",
    "\n",
    "2、在D和G中均使用batch normalization\n",
    "\n",
    "3、去掉FC层，使网络变为全卷积网络\n",
    "\n",
    "4、G网络中使用ReLU作为激活函数，最后一层使用tanh\n",
    "\n",
    "5、D网络中使用LeakyReLU作为激活函数\n",
    "\n",
    "具体的结构如下图所示：\n",
    "\n",
    "Generator的结构：\n",
    "\n",
    "![DCGAN中的G结构图](../images/DCGAN/DCGAN_G.png)\n",
    "\n",
    "Discriminator的结构：\n",
    "\n",
    "![DCGAN中的D结构图](../images/DCGAN/DCGAN_D.png)\n",
    "\n",
    "\n",
    "然后几个心得是：\n",
    "\n",
    "1、GAN真的难训练，learning_rate尽量小点\n",
    "\n",
    "2、自己尝试过的将最后的Conv层换成Linear层，发现效果真的差...\n",
    "\n",
    "3、在没看WGAN之前，以DCGAN来说的话，似乎没有什么指标可以让我们定量地分析训练过程中GAN表现的好坏。\n",
    "\n",
    "4、看loss很容易看到傻眼... 有时候loss真的会小到让你怀疑网络是不是出错了...\n",
    "\n",
    "5、最后也是最重要的一点就是：写完了代码，并且顺利运行了，但是结果不好，先别急着调参，因为很有可能是你的代码写的有问题。\n",
    "\n",
    "数据：https://pan.baidu.com/s/1eSifHcA  提取码：g5qa \n",
    "\n",
    "以上内容参考 https://zhuanlan.zhihu.com/p/24767059 "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 结果展示\n",
    "\n",
    "\n",
    "训练10次的结果\n",
    "\n",
    "![DCGAN10次迭代图](./DCGAN_result/fake_images-10.png)\n",
    "\n",
    "训练50次的结果\n",
    "\n",
    "![DCGAN50次迭代图](./DCGAN_result/fake_images-50.png)\n",
    "\n",
    "训练100次的结果\n",
    "\n",
    "![DCGAN100次迭代图](./DCGAN_result/fake_images-100.png)\n",
    "\n",
    "训练200次的结果\n",
    "\n",
    "![DCGAN200次迭代图](./DCGAN_result/fake_images-200.png)\n",
    "\n",
    "训练300次的结果\n",
    "\n",
    "![DCGAN300次迭代图](./DCGAN_result/fake_images-300.png)\n",
    "\n",
    "训练400次的结果\n",
    "\n",
    "![DCGAN400次迭代图](./DCGAN_result/fake_images-400.png)\n",
    "\n",
    "训练500次的结果\n",
    "\n",
    "![DCGAN500次迭代图](./DCGAN_result/fake_images-500.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "from torchvision import transforms\n",
    "from torch.utils import data\n",
    "import os\n",
    "from PIL import Image\n",
    "from torch import optim\n",
    "import torch.nn.functional as F\n",
    "from torchvision.utils import save_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Generator network\n",
    "class Generator(nn.Module):\n",
    "    \"\"\"DCGAN generator: turns a batch of noise vectors into 3-channel images.\n",
    "\n",
    "    Args:\n",
    "        noise_dim: length of each input noise vector.\n",
    "        conv_dim: base channel width; the layers use 16x, 8x, 4x and 2x of it.\n",
    "        g_img_size: side length of the generated image. The first layer emits\n",
    "            a (g_img_size / 16)-sized map and each of the four stride-2 layers\n",
    "            doubles it, so the value should be a multiple of 16.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, noise_dim=100, conv_dim=64, g_img_size=64):\n",
    "        super(Generator, self).__init__()\n",
    "        # Projection of the 1x1 noise map to a (g_img_size/16)-sized feature map\n",
    "        self.fc = nn.ConvTranspose2d(noise_dim, conv_dim*16, kernel_size=int(g_img_size/16), stride=1, padding=0)\n",
    "        # kernel 4 / stride 2 / padding 1 doubles the spatial size at every step\n",
    "        self.deconv2 = nn.ConvTranspose2d(conv_dim*16, conv_dim*8, kernel_size=4, stride=2, padding=1)\n",
    "        self.bn2 = nn.BatchNorm2d(conv_dim * 8)\n",
    "        self.deconv3 = nn.ConvTranspose2d(conv_dim*8, conv_dim*4, kernel_size=4, stride=2, padding=1)\n",
    "        self.bn3 = nn.BatchNorm2d(conv_dim * 4)\n",
    "        self.deconv4 = nn.ConvTranspose2d(conv_dim*4, conv_dim*2, kernel_size=4, stride=2, padding=1)\n",
    "        self.bn4 = nn.BatchNorm2d(conv_dim * 2)\n",
    "        self.deconv5 = nn.ConvTranspose2d(conv_dim*2, 3, kernel_size=4, stride=2, padding=1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map noise of shape (N, noise_dim) to images with values in [-1, 1].\"\"\"\n",
    "        # Treat every noise vector as a 1x1 feature map with noise_dim channels\n",
    "        x = x.view(x.size(0), x.size(1), 1, 1)\n",
    "        # NOTE(review): no batch norm / activation between fc and deconv2, so the\n",
    "        # two layers compose linearly - confirm this is intended.\n",
    "        A1 = self.fc(x)\n",
    "        A2 = F.relu(self.bn2(self.deconv2(A1)))\n",
    "        A3 = F.relu(self.bn3(self.deconv3(A2)))\n",
    "        A4 = F.relu(self.bn4(self.deconv4(A3)))\n",
    "        # torch.tanh replaces the deprecated F.tanh\n",
    "        y_hat = torch.tanh(self.deconv5(A4))\n",
    "        return y_hat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Discriminator network\n",
    "class Discriminator(nn.Module):\n",
    "    \"\"\"DCGAN discriminator: scores each image with a probability of being real.\n",
    "\n",
    "    Args:\n",
    "        input_chanel: number of channels of the input images.\n",
    "        conv_dim: base channel width; the layers use 2x, 4x, 8x and 16x of it.\n",
    "        d_img_size: side length of the input images. Four stride-2 convolutions\n",
    "            halve it each time and the last layer's kernel covers the remaining\n",
    "            (d_img_size / 16)-sized map, so the value should be a multiple of 16.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_chanel=3, conv_dim=64, d_img_size=64):\n",
    "        super(Discriminator, self).__init__()\n",
    "        # Strided convolutions take the place of pooling layers\n",
    "        self.conv1 = nn.Conv2d(input_chanel, conv_dim * 2, kernel_size=4, stride=2, padding=1)\n",
    "        self.conv2 = nn.Conv2d(conv_dim * 2, conv_dim * 4, kernel_size=4, stride=2, padding=1)\n",
    "        self.bn2 = nn.BatchNorm2d(conv_dim * 4)\n",
    "        self.conv3 = nn.Conv2d(conv_dim * 4, conv_dim * 8, kernel_size=4, stride=2, padding=1)\n",
    "        self.bn3 = nn.BatchNorm2d(conv_dim * 8)\n",
    "        self.conv4 = nn.Conv2d(conv_dim * 8, conv_dim * 16, kernel_size=4, stride=2, padding=1)\n",
    "        self.bn4 = nn.BatchNorm2d(conv_dim * 16)\n",
    "        # Convolution that collapses the final feature map into one logit\n",
    "        self.fc = nn.Conv2d(conv_dim * 16, 1, int(d_img_size / 16), 1, 0)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return a 1-D tensor of N probabilities for images of shape (N, C, H, W).\"\"\"\n",
    "        # NOTE(review): this layer uses the default negative slope while the\n",
    "        # others pass 0.05 - confirm the difference is intended.\n",
    "        A1 = F.leaky_relu(self.conv1(x))\n",
    "        A2 = F.leaky_relu(self.bn2(self.conv2(A1)), 0.05)\n",
    "        A3 = F.leaky_relu(self.bn3(self.conv3(A2)), 0.05)\n",
    "        A4 = F.leaky_relu(self.bn4(self.conv4(A3)), 0.05)\n",
    "        # view(-1) keeps the batch axis even for a batch of one; squeeze() would\n",
    "        # return a 0-d tensor that no longer matches the (N,) label tensor.\n",
    "        # torch.sigmoid replaces the deprecated F.sigmoid.\n",
    "        y_hat = torch.sigmoid(self.fc(A4).view(-1))\n",
    "        return y_hat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Dataset handling\n",
    "class ImageDataset(data.Dataset):\n",
    "    \"\"\"Dataset yielding every regular file of one directory as an RGB image.\"\"\"\n",
    "\n",
    "    def __init__(self, path, transform=None):\n",
    "        \"\"\"\n",
    "            path: directory that contains the image files.\n",
    "            transform: optional callable applied to each loaded PIL image.\n",
    "        \"\"\"\n",
    "        # os.listdir returns names in arbitrary order and also lists\n",
    "        # sub-directories: sort for a reproducible index -> file mapping and\n",
    "        # keep only regular files so Image.open is never handed a directory.\n",
    "        candidates = (os.path.join(path, name) for name in sorted(os.listdir(path)))\n",
    "        self.images = [file_path for file_path in candidates if os.path.isfile(file_path)]\n",
    "        self.transform = transform\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Load image number `index`, convert it to RGB and apply the transform.\"\"\"\n",
    "        image_file = self.images[index]\n",
    "        # convert() returns a new image, so the file can be closed right after\n",
    "        with Image.open(image_file) as opened:\n",
    "            image = opened.convert('RGB')\n",
    "        if self.transform is not None:\n",
    "            image = self.transform(image)\n",
    "        return image\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of image files found in the directory.\"\"\"\n",
    "        return len(self.images)\n",
    "\n",
    "\n",
    "def get_dataset(path, img_scale, batch_size):\n",
    "    \"\"\"Build a shuffling DataLoader over the images stored in `path`.\n",
    "\n",
    "    Args:\n",
    "        path: directory handed to ImageDataset.\n",
    "        img_scale: side length of the square images that are produced.\n",
    "        batch_size: images per batch; a trailing incomplete batch is dropped.\n",
    "\n",
    "    Returns:\n",
    "        A DataLoader yielding tensors normalised with mean 0.5 and std 0.5 per\n",
    "        channel, i.e. values in [-1, 1].\n",
    "    \"\"\"\n",
    "    transform = transforms.Compose([\n",
    "        # Resize with an int only fixes the shorter edge (the original note\n",
    "        # says transforms.Scale can be used here as well) ...\n",
    "        transforms.Resize(img_scale),\n",
    "        # ... so crop to a square; otherwise non-square images produce tensors\n",
    "        # of different sizes that cannot be stacked into one batch.\n",
    "        transforms.CenterCrop(img_scale),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
    "    ])\n",
    "\n",
    "    dataset = ImageDataset(path, transform)\n",
    "    data_loader = data.DataLoader(dataset=dataset,\n",
    "                                  batch_size=batch_size,\n",
    "                                  shuffle=True,\n",
    "                                  drop_last=True)\n",
    "    return data_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 工具函数:\n",
    "\n",
    "# Sample the noise z\n",
    "def gen_noisy(batch_size, noisy_dim):\n",
    "    \"\"\"Return a (batch_size, noisy_dim) tensor of standard-normal samples.\"\"\"\n",
    "    noise_shape = (batch_size, noisy_dim)\n",
    "    return torch.randn(*noise_shape)\n",
    "\n",
    "\n",
    "# tensor -> Variable, placed on the GPU when one is available\n",
    "def to_variable(x):\n",
    "    \"\"\"Wrap `x` in a Variable and move it to CUDA if torch reports a usable GPU.\"\"\"\n",
    "    wrapped = Variable(x)\n",
    "    return wrapped.cuda() if torch.cuda.is_available() else wrapped"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# A Config class that keeps all hyper-parameters in one place\n",
    "class Config(object):\n",
    "    \"\"\"Hyper-parameters and paths used by the model-setup and training cells.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.batch_size = 128  # images per batch handed to get_dataset\n",
    "        self.image_path = './faces/' # directory holding the training images; adjust as needed\n",
    "        self.noisy_dim = 100  # length of the noise vector fed to the Generator\n",
    "        self.G_lr = 2*1e-6  # Adam learning rate for the Generator\n",
    "        self.D_lr = 2*1e-6  # Adam learning rate for the Discriminator\n",
    "        self.EPOCH = 500  # number of passes over the dataset\n",
    "        self.img_scale = 64  # size the training images are resized to\n",
    "        self.k_step = 5  # not referenced by the other cells of this notebook\n",
    "        # Default to the GPU only when CUDA is actually usable, so the notebook\n",
    "        # also runs on CPU-only machines; set to False to force the CPU.\n",
    "        self.use_gpu = torch.cuda.is_available()\n",
    "        self.g_img_size = 64  # side length of the images the Generator emits\n",
    "        self.d_img_size = 64  # side length of the images the Discriminator expects"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Training stage: build the Generator and Discriminator before training starts\n",
    "config = Config()\n",
    "\n",
    "# Honour config.use_gpu only when CUDA is really usable; calling .cuda() on a\n",
    "# machine without a GPU raises instead of falling back to the CPU.\n",
    "use_cuda = config.use_gpu and torch.cuda.is_available()\n",
    "\n",
    "G = Generator(noise_dim=config.noisy_dim, g_img_size=config.g_img_size)\n",
    "D = Discriminator(d_img_size=config.d_img_size)\n",
    "if use_cuda:\n",
    "    G = G.cuda()\n",
    "    D = D.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generator的结构是：\n",
      "Generator (\n",
      "  (fc): ConvTranspose2d(100, 1024, kernel_size=(4, 4), stride=(1, 1))\n",
      "  (deconv2): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
      "  (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)\n",
      "  (deconv3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
      "  (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)\n",
      "  (deconv4): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
      "  (bn4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)\n",
      "  (deconv5): ConvTranspose2d(128, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Show the layer layout of the Generator\n",
    "print(\"Generator的结构是：\", G, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Discriminator的结构是：\n",
      "Discriminator (\n",
      "  (conv1): Conv2d(3, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
      "  (conv2): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
      "  (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)\n",
      "  (conv3): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
      "  (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)\n",
      "  (conv4): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
      "  (bn4): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True)\n",
      "  (fc): Conv2d(1024, 1, kernel_size=(4, 4), stride=(1, 1))\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Show the layer layout of the Discriminator\n",
    "print(\"Discriminator的结构是：\", D, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# The actual training stage\n",
    "\n",
    "train_data_loader = get_dataset(config.image_path, config.img_scale, config.batch_size)\n",
    "loss_fn = torch.nn.BCELoss()\n",
    "g_optimizer = optim.Adam(G.parameters(), lr=config.G_lr)\n",
    "d_optimizer = optim.Adam(D.parameters(), lr=config.D_lr)\n",
    "\n",
    "# Put every input on the device the models actually live on, instead of on\n",
    "# 'CUDA if available', which can disagree with where G and D were placed.\n",
    "device = next(G.parameters()).device\n",
    "\n",
    "# drop_last=True discards incomplete batches, so a folder with fewer than\n",
    "# batch_size images yields no batch at all; fail with a clear message.\n",
    "steps_per_epoch = len(train_data_loader)\n",
    "if steps_per_epoch == 0:\n",
    "    raise ValueError('no full batch available: config.image_path must hold at least config.batch_size images')\n",
    "\n",
    "# save_image does not create missing directories, so make sure it exists\n",
    "os.makedirs('./test_DCGAN', exist_ok=True)\n",
    "\n",
    "for epoch in range(config.EPOCH):\n",
    "    g_total_loss = 0.0\n",
    "    d_total_loss = 0.0\n",
    "\n",
    "    # The loop variable is deliberately not called data: that name would\n",
    "    # shadow the imported torch.utils.data module used by get_dataset.\n",
    "    for i, real_batch in enumerate(train_data_loader):\n",
    "        images = real_batch.to(device)\n",
    "        batch_size = images.size(0)\n",
    "\n",
    "        z = gen_noisy(batch_size, config.noisy_dim).to(device)\n",
    "\n",
    "        real_labels = torch.ones(batch_size, device=device)\n",
    "        fake_labels = torch.zeros(batch_size, device=device)\n",
    "\n",
    "\n",
    "        ###          train D           ###\n",
    "        outputs = D(images)\n",
    "        d_loss_real = loss_fn(outputs, real_labels)\n",
    "        real_score = outputs\n",
    "\n",
    "        # detach(): the D update needs no gradient w.r.t. G, so do not\n",
    "        # back-propagate through the generator here.\n",
    "        fake_images = G(z).detach()\n",
    "        outputs = D(fake_images)\n",
    "        d_loss_fake = loss_fn(outputs, fake_labels)\n",
    "        fake_score = outputs\n",
    "\n",
    "        d_loss = d_loss_real + d_loss_fake\n",
    "        # accumulate the epoch loss; .item() reads the scalar out of the\n",
    "        # 0-dim loss tensor (indexing it with [0] raises on current PyTorch)\n",
    "        d_total_loss += d_loss.item()\n",
    "        D.zero_grad()\n",
    "        d_loss.backward()\n",
    "        d_optimizer.step()\n",
    "\n",
    "\n",
    "        ###          train G           ###\n",
    "        z = gen_noisy(batch_size, config.noisy_dim).to(device)\n",
    "        fake_images = G(z)\n",
    "        outputs = D(fake_images)\n",
    "        # G is rewarded when D labels its samples as real\n",
    "        g_loss = loss_fn(outputs, real_labels)\n",
    "        g_total_loss += g_loss.item()\n",
    "        G.zero_grad()\n",
    "        g_loss.backward()\n",
    "        g_optimizer.step()\n",
    "\n",
    "        if (i + 1) % 150 == 0:\n",
    "            print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, '\n",
    "                  'g_loss: %.4f, D(x): %.2f, D(G(z)): %.2f'\n",
    "                  % (epoch+1, config.EPOCH, i + 1, steps_per_epoch, d_loss.item(), g_loss.item(),\n",
    "                     real_score.mean().item(), fake_score.mean().item()))\n",
    "    print('Epoch [%d/%d]'% (epoch + 1, config.EPOCH))\n",
    "    print('D 的 total loss', d_total_loss / steps_per_epoch)\n",
    "    print('G 的 total loss', g_total_loss / steps_per_epoch)\n",
    "\n",
    "    # Save the fake images produced by the last generator step of this epoch\n",
    "    fake_images = fake_images.detach().view(fake_images.size(0), 3, config.g_img_size, config.g_img_size)\n",
    "    # The generator ends in tanh, so pixels lie in [-1, 1]; rescale to [0, 1]\n",
    "    # because save_image would otherwise clip every negative value to black.\n",
    "    save_image((fake_images + 1) / 2, './test_DCGAN/fake_images-%d.png' % (epoch + 1))\n"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
